# SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC

# SPDX-License-Identifier: Apache-2.0

import os
import warnings
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
from torch.nn.parameter import Parameter

# Activation modules looked up by name (see ConvModule). A single instance per
# name is created here and shared by every module that looks it up.
activation_table = dict(
    relu=nn.ReLU(),
    silu=nn.SiLU(),
    hardswish=nn.Hardswish(),
)


class SiLU(nn.Module):
    """Activation of SiLU"""

    @staticmethod
    def forward(x):
        return x * torch.sigmoid(x)


class ConvModule(nn.Module):
    """A combination of Conv + BN + Activation.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        kernel_size (int or tuple): Convolution kernel size.
        stride (int or tuple): Convolution stride.
        activation_type (str or None): Key into the module-level ``activation_table``,
            or None for no activation.
        padding (int or tuple, optional): Explicit padding. When None, ``k // 2`` is
            used for every kernel dimension ``k``.
        groups (int): Convolution groups.
        bias (bool): Whether the convolution has a bias term.

    Raises:
        ValueError: If ``activation_type`` is neither None nor a key of ``activation_table``.
    """

    def __init__(
        self, in_channels, out_channels, kernel_size, stride, activation_type, padding=None, groups=1, bias=False
    ):
        super().__init__()
        if padding is None:
            # Per-dimension "k // 2" so tuple kernels work as well as ints.
            if isinstance(kernel_size, (tuple, list)):
                padding = tuple(k // 2 for k in kernel_size)
            else:
                padding = kernel_size // 2
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=groups,
            bias=bias,
        )
        self.bn = nn.BatchNorm2d(out_channels)
        if activation_type is not None:
            act = activation_table.get(activation_type)
            if act is None:
                # Fail here rather than with an opaque "NoneType is not callable" in forward.
                raise ValueError(
                    f"Unsupported activation_type {activation_type!r}; "
                    f"expected None or one of {sorted(activation_table)}"
                )
            # The same activation instance is shared by every ConvModule using this key.
            self.act = act
        self.activation_type = activation_type

    def forward(self, x):
        """Apply conv -> BN -> activation (activation skipped when ``activation_type`` is None)."""
        if self.activation_type is None:
            return self.bn(self.conv(x))
        return self.act(self.bn(self.conv(x)))

    def forward_fuse(self, x):
        """Apply conv -> activation, skipping BN (for use once BN has been folded into conv)."""
        if self.activation_type is None:
            return self.conv(x)
        return self.act(self.conv(x))


class ConvBNReLU(nn.Module):
    """Conv2d followed by BatchNorm2d and ReLU, wrapped in a single ``ConvModule``.

    All arguments are forwarded to ``ConvModule`` with ``activation_type="relu"``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=None, groups=1, bias=False):
        super().__init__()
        self.block = ConvModule(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            activation_type="relu",
            padding=padding,
            groups=groups,
            bias=bias,
        )

    def forward(self, x):
        layer = self.block
        return layer(x)


class ConvBNSiLU(nn.Module):
    """Conv2d followed by BatchNorm2d and SiLU, wrapped in a single ``ConvModule``.

    All arguments are forwarded to ``ConvModule`` with ``activation_type="silu"``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=None, groups=1, bias=False):
        super().__init__()
        self.block = ConvModule(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            activation_type="silu",
            padding=padding,
            groups=groups,
            bias=bias,
        )

    def forward(self, x):
        layer = self.block
        return layer(x)


class SPPFModule(nn.Module):
    """SPPF: one stride-1 max-pool applied three times in sequence.

    ``cv1`` squeezes the input to ``in_channels // 2`` channels. The squeezed map and
    its 1x, 2x and 3x pooled versions are concatenated on the channel axis, and ``cv2``
    projects the result to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=5, block=ConvBNReLU):
        super().__init__()
        hidden = in_channels // 2
        self.cv1 = block(in_channels, hidden, 1, 1)
        self.cv2 = block(4 * hidden, out_channels, 1, 1)
        self.m = nn.MaxPool2d(kernel_size=kernel_size, stride=1, padding=kernel_size // 2)

    def forward(self, x):
        pyramid = [self.cv1(x)]
        # Every warning raised inside this context (pooling and cv2) is suppressed.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            for _ in range(3):
                pyramid.append(self.m(pyramid[-1]))
            return self.cv2(torch.cat(pyramid, 1))


class SPPF(nn.Module):
    """SPPF with SiLU activation; a thin wrapper that delegates to ``SPPFModule``."""

    def __init__(self, in_channels, out_channels, kernel_size=5, block=ConvBNSiLU):
        super().__init__()
        self.sppf = SPPFModule(in_channels, out_channels, kernel_size=kernel_size, block=block)

    def forward(self, x):
        pooled = self.sppf(x)
        return pooled


class CSPSPPFModule(nn.Module):
    """CSP-style SPPF block.

    One path runs cv1 -> cv3 -> cv4. It then applies the same stride-1 max-pool three
    times in sequence, concatenates the four maps, and fuses them with cv5 -> cv6. A
    parallel 1x1 path (cv2) is concatenated with that result, and cv7 projects it to
    ``out_channels``.
    """

    # CSP https://github.com/WongKinYiu/CrossStagePartialNetworks
    def __init__(self, in_channels, out_channels, kernel_size=5, e=0.5, block=ConvBNReLU):
        super().__init__()
        hidden = int(out_channels * e)  # hidden channels
        self.cv1 = block(in_channels, hidden, 1, 1)
        self.cv2 = block(in_channels, hidden, 1, 1)
        self.cv3 = block(hidden, hidden, 3, 1)
        self.cv4 = block(hidden, hidden, 1, 1)

        self.m = nn.MaxPool2d(kernel_size=kernel_size, stride=1, padding=kernel_size // 2)
        self.cv5 = block(4 * hidden, hidden, 1, 1)
        self.cv6 = block(hidden, hidden, 3, 1)
        self.cv7 = block(2 * hidden, out_channels, 1, 1)

    def forward(self, x):
        pool_input = self.cv4(self.cv3(self.cv1(x)))
        side = self.cv2(x)
        # Warnings raised while pooling/fusing are deliberately silenced.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            pyramid = [pool_input]
            for _ in range(3):
                pyramid.append(self.m(pyramid[-1]))
            fused = self.cv6(self.cv5(torch.cat(pyramid, 1)))
        return self.cv7(torch.cat((side, fused), dim=1))


class Transpose(nn.Module):
    """Learned upsampling via ``ConvTranspose2d`` (with bias); the default kernel 2 / stride 2 doubles H and W."""

    def __init__(self, in_channels, out_channels, kernel_size=2, stride=2):
        super().__init__()
        self.upsample_transpose = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            bias=True,
        )

    def forward(self, x):
        upsample = self.upsample_transpose
        return upsample(x)


class RepVGGBlock(nn.Module):
    """RepVGGBlock is a basic rep-style block, including training and deploy status.

    This code is based on https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py

    In training status the block adds up to three branches:
    - a 3x3 ConvModule (``rbr_dense``);
    - a 1x1 ConvModule (``rbr_1x1``);
    - when ``in_channels == out_channels`` and ``stride == 1``, a BatchNorm identity
      branch (``rbr_identity``).

    The sum goes through ``se`` (an identity here) and ReLU. In deploy status a single
    3x3 convolution (``rbr_reparam``) replaces the branches; ``switch_to_deploy``
    performs that conversion in place.

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel; only 3 is supported.
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of
            the input; only 1 is supported. Default: 1
        dilation (int or tuple, optional): Spacing between kernel elements. Only used
            by the deploy-status convolution. Default: 1
        groups (int, optional): Number of blocked connections from input
            channels to output channels. Default: 1
        padding_mode (string, optional): Only used by the deploy-status convolution. Default: 'zeros'
        deploy: Whether to be deploy status or training status. Default: False
        use_se: Whether to use se. Must be False (True raises NotImplementedError). Default: False
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=3,
        stride=1,
        padding=1,
        dilation=1,
        groups=1,
        padding_mode="zeros",
        deploy=False,
        use_se=False,
    ):
        super().__init__()
        self.deploy = deploy
        self.groups = groups
        self.in_channels = in_channels
        self.out_channels = out_channels

        assert kernel_size == 3
        assert padding == 1

        # Padding for the 1x1 branch so its output aligns with the 3x3 branch (0 here).
        padding_11 = padding - kernel_size // 2

        self.nonlinearity = nn.ReLU()

        if use_se:
            raise NotImplementedError("se block not supported yet")
        else:
            self.se = nn.Identity()

        if deploy:
            self.rbr_reparam = nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
                bias=True,
                padding_mode=padding_mode,
            )

        else:
            self.rbr_identity = (
                nn.BatchNorm2d(num_features=in_channels) if out_channels == in_channels and stride == 1 else None
            )
            self.rbr_dense = ConvModule(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                activation_type=None,
                padding=padding,
                groups=groups,
            )
            self.rbr_1x1 = ConvModule(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
                stride=stride,
                activation_type=None,
                padding=padding_11,
                groups=groups,
            )

    def forward(self, inputs):
        """Forward process"""
        if hasattr(self, "rbr_reparam"):
            return self.nonlinearity(self.se(self.rbr_reparam(inputs)))

        if self.rbr_identity is None:
            id_out = 0
        else:
            id_out = self.rbr_identity(inputs)

        return self.nonlinearity(self.se(self.rbr_dense(inputs) + self.rbr_1x1(inputs) + id_out))

    def get_equivalent_kernel_bias(self):
        """Return (kernel, bias) of one 3x3 conv equivalent to the sum of all branches."""
        kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense)
        kernel1x1, bias1x1 = self._fuse_bn_tensor(self.rbr_1x1)
        kernelid, biasid = self._fuse_bn_tensor(self.rbr_identity)
        return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid

    def _avg_to_3x3_tensor(self, avgp):
        """Build a conv kernel equivalent to the average-pool module ``avgp``."""
        channels = self.in_channels
        groups = self.groups
        kernel_size = avgp.kernel_size
        input_dim = channels // groups
        k = torch.zeros((channels, input_dim, kernel_size, kernel_size))
        k[np.arange(channels), np.tile(np.arange(input_dim), groups), :, :] = 1.0 / kernel_size**2
        return k

    def _pad_1x1_to_3x3_tensor(self, kernel1x1):
        """Zero-pad a 1x1 kernel to 3x3 (centre tap); None maps to 0."""
        if kernel1x1 is None:
            return 0
        else:
            return torch.nn.functional.pad(kernel1x1, [1, 1, 1, 1])

    def _fuse_bn_tensor(self, branch):
        """Return (kernel, bias) for one branch with its BatchNorm folded in.

        Args:
            branch: A ConvModule, an ``nn.BatchNorm2d`` identity branch, or None.

        Returns:
            ``(0, 0)`` for None, otherwise a (kernel, bias) tensor pair.

        Raises:
            TypeError: For any other branch type.
        """
        if branch is None:
            return 0, 0
        if isinstance(branch, ConvModule):
            kernel = branch.conv.weight
            bias = branch.conv.bias
            if bias is None:
                # ConvModule defaults to bias=False; treat a missing bias as zeros so the
                # branch biases can be summed in get_equivalent_kernel_bias.
                bias = torch.zeros(kernel.shape[0], device=kernel.device, dtype=kernel.dtype)
            bn = getattr(branch, "bn", None)
            if bn is None:
                # BN has already been folded into ``conv`` (a fuse step that removed ``bn``).
                return kernel, bias
            # ConvModule.forward applies ``bn`` after ``conv``, so fold it in here.
            std = (bn.running_var + bn.eps).sqrt()
            scale = bn.weight / std
            return kernel * scale.reshape(-1, 1, 1, 1), bn.bias + (bias - bn.running_mean) * scale
        elif isinstance(branch, nn.BatchNorm2d):
            if not hasattr(self, "id_tensor"):
                # Identity expressed as a 3x3 kernel: a 1 at the centre tap on each
                # channel's own input (within its group).
                input_dim = self.in_channels // self.groups
                kernel_value = np.zeros((self.in_channels, input_dim, 3, 3), dtype=np.float32)
                for i in range(self.in_channels):
                    kernel_value[i, i % input_dim, 1, 1] = 1
                self.id_tensor = torch.from_numpy(kernel_value).to(
                    device=branch.weight.device, dtype=branch.weight.dtype
                )
            kernel = self.id_tensor
            running_mean = branch.running_mean
            running_var = branch.running_var
            gamma = branch.weight
            beta = branch.bias
            eps = branch.eps
            std = (running_var + eps).sqrt()
            t = (gamma / std).reshape(-1, 1, 1, 1)
            return kernel * t, beta - running_mean * gamma / std
        raise TypeError(f"Unsupported branch type for fusion: {type(branch).__name__}")

    def switch_to_deploy(self):
        """Replace the training branches with a single equivalent 3x3 conv, in place."""
        if hasattr(self, "rbr_reparam"):
            return
        kernel, bias = self.get_equivalent_kernel_bias()
        self.rbr_reparam = nn.Conv2d(
            in_channels=self.rbr_dense.conv.in_channels,
            out_channels=self.rbr_dense.conv.out_channels,
            kernel_size=self.rbr_dense.conv.kernel_size,
            stride=self.rbr_dense.conv.stride,
            padding=self.rbr_dense.conv.padding,
            dilation=self.rbr_dense.conv.dilation,
            groups=self.rbr_dense.conv.groups,
            bias=True,
        )
        self.rbr_reparam.weight.data = kernel
        self.rbr_reparam.bias.data = bias
        for para in self.parameters():
            para.detach_()
        self.__delattr__("rbr_dense")
        self.__delattr__("rbr_1x1")
        if hasattr(self, "rbr_identity"):
            self.__delattr__("rbr_identity")
        if hasattr(self, "id_tensor"):
            self.__delattr__("id_tensor")
        self.deploy = True


class DetectBackend(nn.Module):
    """Load a YOLOv6 ``.pt`` checkpoint (via ``yolov6``) and run it.

    Args:
        weights (str): Path to a ``.pt`` checkpoint. If it does not exist locally, a
            download is attempted with ``yolov6.utils.general.download_ckpt``.
        device: Device used in ``forward`` when the model returns a numpy array.
        dnn (bool): Stored as ``self.dnn``; not otherwise used here.
    """

    def __init__(self, weights="yolov6s.pt", device=None, dnn=True):
        super().__init__()
        # Validate the format before attempting any download.
        assert (
            isinstance(weights, str) and Path(weights).suffix == ".pt"
        ), f"{Path(weights).suffix} format is not supported."
        if not os.path.exists(weights):
            # download_ckpt is not defined in this module; import it lazily, like
            # load_checkpoint below, instead of failing with a NameError.
            from yolov6.utils.general import download_ckpt

            download_ckpt(weights)  # try to download model from github automatically.
        from yolov6.utils.checkpoint import load_checkpoint

        model = load_checkpoint(weights, map_location=torch.device("cpu"))
        self.weights = weights
        self.device = device
        self.dnn = dnn
        # Plain attribute assignment registers ``model`` as a submodule, so .to()/.eval()
        # and state_dict() reach it (a __dict__ update would bypass nn.Module).
        self.model = model
        self.stride = int(model.stride.max())

    def forward(self, im, val=False):
        """Run the model and return its first output (converted to a tensor if numpy).

        ``val`` is accepted for interface compatibility and is unused.
        """
        y, _ = self.model(im)
        if isinstance(y, np.ndarray):
            y = torch.tensor(y, device=self.device)

        return y


class RepBlock(nn.Module):
    """
    RepBlock is a stage block with rep-style basic block.

    Args:
        in_channels (int): Input channels of the first block.
        out_channels (int): Output channels of every block.
        n (int): Depth. With ``block=BottleRep``, ``n`` is halved (each BottleRep stacks
            two basic blocks) and one BottleRep is always built. Otherwise ``n`` blocks
            are built (at least one).
        block: Block class to stack; ``BottleRep`` selects the BottleRep layout.
        basic_block: Inner block class handed to each ``BottleRep``.
    """

    def __init__(self, in_channels, out_channels, n=1, block=RepVGGBlock, basic_block=RepVGGBlock):
        super().__init__()

        # Build only the layout that is actually used, so no throwaway modules are created.
        if block is BottleRep:
            self.conv1 = BottleRep(in_channels, out_channels, basic_block=basic_block, weight=True)
            n = n // 2
            repeated = (
                BottleRep(out_channels, out_channels, basic_block=basic_block, weight=True) for _ in range(n - 1)
            )
        else:
            self.conv1 = block(in_channels, out_channels)
            repeated = (block(out_channels, out_channels) for _ in range(n - 1))
        # The generator is only consumed (modules only built) when more than one block is needed.
        self.block = nn.Sequential(*repeated) if n > 1 else None

    def forward(self, x):
        x = self.conv1(x)
        if self.block is not None:
            x = self.block(x)
        return x


class BottleRep(nn.Module):
    """Two stacked basic blocks with an optional residual connection.

    Args:
        in_channels (int): Input channels of the first block.
        out_channels (int): Output channels of both blocks.
        basic_block: Block class used for ``conv1`` and ``conv2``.
        weight (bool): If True, the residual is scaled by a learnable scalar ``alpha``
            initialised to 1; otherwise ``alpha`` is the constant 1.0.
    """

    def __init__(self, in_channels, out_channels, basic_block=RepVGGBlock, weight=False):
        super().__init__()
        self.conv1 = basic_block(in_channels, out_channels)
        self.conv2 = basic_block(out_channels, out_channels)
        # The residual is only added when the channel counts match.
        self.shortcut = in_channels == out_channels
        self.alpha = Parameter(torch.ones(1)) if weight else 1.0

    def forward(self, x):
        out = self.conv2(self.conv1(x))
        if not self.shortcut:
            return out
        return out + self.alpha * x


class BottleRep3(nn.Module):
    """Three stacked basic blocks with an optional residual connection.

    Args:
        in_channels (int): Input channels of the first block.
        out_channels (int): Output channels of all three blocks.
        basic_block: Block class used for ``conv1``, ``conv2`` and ``conv3``.
        weight (bool): If True, the residual is scaled by a learnable scalar ``alpha``
            initialised to 1; otherwise ``alpha`` is the constant 1.0.
    """

    def __init__(self, in_channels, out_channels, basic_block=RepVGGBlock, weight=False):
        super().__init__()
        self.conv1 = basic_block(in_channels, out_channels)
        self.conv2 = basic_block(out_channels, out_channels)
        self.conv3 = basic_block(out_channels, out_channels)
        # The residual is only added when the channel counts match.
        self.shortcut = in_channels == out_channels
        self.alpha = Parameter(torch.ones(1)) if weight else 1.0

    def forward(self, x):
        out = x
        for layer in (self.conv1, self.conv2, self.conv3):
            out = layer(out)
        if not self.shortcut:
            return out
        return out + self.alpha * x


class BepC3(nn.Module):
    """CSPStackRep Block.

    ``cv1`` feeds a RepBlock of BottleRep units (``m``), and ``cv2`` is a parallel 1x1
    branch. ``cv3`` projects their concatenation to ``out_channels``. The three 1x1
    convs use SiLU when ``block`` is ``ConvBNSiLU`` and ReLU otherwise.

    Args:
        in_channels (int): Input channels.
        out_channels (int): Output channels.
        n (int): Depth passed to the inner RepBlock.
        e (float): Hidden-channel ratio (``int(out_channels * e)``).
        block: Basic block class used inside the BottleRep units.
    """

    def __init__(self, in_channels, out_channels, n=1, e=0.5, block=RepVGGBlock):
        super().__init__()
        c_ = int(out_channels * e)  # hidden channels
        # Choose the conv flavour once instead of building ReLU convs and replacing them.
        conv = ConvBNSiLU if block == ConvBNSiLU else ConvBNReLU
        self.cv1 = conv(in_channels, c_, 1, 1)
        self.cv2 = conv(in_channels, c_, 1, 1)
        self.cv3 = conv(2 * c_, out_channels, 1, 1)

        self.m = RepBlock(in_channels=c_, out_channels=c_, n=n, block=BottleRep, basic_block=block)

    def forward(self, x):
        return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), dim=1))


class MBLABlock(nn.Module):
    """Multi Branch Layer Aggregation Block.

    ``cv1`` expands the input into ``len(n_list)`` chunks of ``c_`` channels each.
    Chunk 0 is kept as is. Chunk ``i`` (i >= 1) is passed through ``n_list[i]``
    chained BottleRep3 units, and every intermediate output is kept. All kept tensors
    are concatenated, and ``cv2`` projects them to ``out_channels``. Both 1x1 convs use
    SiLU when ``block`` is ``ConvBNSiLU`` and ReLU otherwise.
    """

    def __init__(self, in_channels, out_channels, n=1, e=0.5, block=RepVGGBlock):
        super().__init__()
        n = n // 2
        if n <= 0:
            n = 1

        # max add one branch
        if n == 1:
            n_list = [0, 1]
        else:
            # Largest power of two strictly below n becomes the middle branch depth.
            extra_branch_steps = 1
            while extra_branch_steps * 2 < n:
                extra_branch_steps *= 2
            n_list = [0, extra_branch_steps, n]
        branch_num = len(n_list)

        c_ = int(out_channels * e)  # hidden channels
        self.c = c_
        # Choose the activation once instead of building ReLU convs and replacing them.
        act = "silu" if block == ConvBNSiLU else "relu"
        self.cv1 = ConvModule(in_channels, branch_num * self.c, 1, 1, act, bias=False)
        self.cv2 = ConvModule((sum(n_list) + branch_num) * self.c, out_channels, 1, 1, act, bias=False)

        self.m = nn.ModuleList()
        for n_list_i in n_list[1:]:
            self.m.append(
                nn.Sequential(*(BottleRep3(self.c, self.c, basic_block=block, weight=True) for _ in range(n_list_i)))
            )

        self.split_num = tuple([self.c] * branch_num)

    def forward(self, x):
        chunks = self.cv1(x).split(self.split_num, 1)
        outputs = [chunks[0]]
        for m_idx, stage in enumerate(self.m):
            # Each unit consumes the previous unit's output; every output is kept.
            feat = chunks[m_idx + 1]
            outputs.append(feat)
            for unit in stage:
                feat = unit(feat)
                outputs.append(feat)
        return self.cv2(torch.cat(outputs, 1))


class BiFusion(nn.Module):
    """BiFusion Block in PAN.

    Fuses three feature maps:
    - ``x[0]`` is upsampled by a ``Transpose``;
    - ``x[1]`` is projected by a 1x1 conv;
    - ``x[2]`` is projected by a 1x1 conv and then downsampled by a stride-2 3x3 conv.

    The three results are concatenated, and ``cv3`` projects them to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.cv1 = ConvBNReLU(in_channels[0], out_channels, 1, 1)
        self.cv2 = ConvBNReLU(in_channels[1], out_channels, 1, 1)
        self.cv3 = ConvBNReLU(3 * out_channels, out_channels, 1, 1)

        self.upsample = Transpose(in_channels=out_channels, out_channels=out_channels)
        self.downsample = ConvBNReLU(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=2,
        )

    def forward(self, x):
        upper, same, lower = x[0], x[1], x[2]
        up = self.upsample(upper)
        mid = self.cv1(same)
        down = self.downsample(self.cv2(lower))
        return self.cv3(torch.cat((up, mid, down), dim=1))


def get_block(mode):
    """Return the block class selected by ``mode``.

    Args:
        mode (str): One of "repvgg", "qarepvgg", "qarepvggv2", "hyper_search",
            "repopt", "conv_relu", "conv_silu".

    Returns:
        The block class registered for ``mode``.

    Raises:
        NotImplementedError: If ``mode`` is unknown, or its block class is not
            defined in this module.
    """
    block_names = {
        "repvgg": "RepVGGBlock",
        "qarepvgg": "QARepVGGBlock",
        "qarepvggv2": "QARepVGGBlockV2",
        "hyper_search": "LinearAddBlock",
        "repopt": "RealVGGBlock",
        "conv_relu": "ConvBNReLU",
        "conv_silu": "ConvBNSiLU",
    }
    if mode not in block_names:
        raise NotImplementedError("Undefined Repblock choice for mode {}".format(mode))
    # Resolve by name so modes whose class is absent from this module raise a clear
    # NotImplementedError instead of a NameError.
    block = globals().get(block_names[mode])
    if block is None:
        raise NotImplementedError(
            "Repblock choice for mode {} ({}) is not available in this module".format(mode, block_names[mode])
        )
    return block


class SEBlock(nn.Module):
    """Squeeze-and-excitation gate.

    Global average pooling is followed by 1x1 conv (``channel // reduction``
    channels) -> ReLU -> 1x1 conv (``channel`` channels) -> Hardsigmoid. The result
    scales the input per channel.
    """

    def __init__(self, channel, reduction=4):
        super().__init__()
        squeezed = channel // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv1 = nn.Conv2d(channel, squeezed, kernel_size=1, stride=1, padding=0)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(squeezed, channel, kernel_size=1, stride=1, padding=0)
        self.hardsigmoid = nn.Hardsigmoid()

    def forward(self, x):
        weights = self.avg_pool(x)
        weights = self.relu(self.conv1(weights))
        weights = self.hardsigmoid(self.conv2(weights))
        return x * weights


def channel_shuffle(x, groups):
    """Interleave the channels of a 4-D tensor across ``groups`` (ShuffleNet-style shuffle).

    Args:
        x (torch.Tensor): Tensor of shape (N, C, H, W); C is expected to be divisible
            by ``groups``.
        groups (int): Number of channel groups.

    Returns:
        torch.Tensor: Same shape as ``x``. Input channel ``g * (C // groups) + i`` moves
        to output position ``i * groups + g``.
    """
    # Read the shape directly instead of going through the discouraged ``.data``.
    batchsize, num_channels, height, width = x.shape
    channels_per_group = num_channels // groups
    # (N, groups, C/groups, H, W) -> swap the two channel axes -> flatten back.
    x = x.view(batchsize, groups, channels_per_group, height, width)
    x = x.transpose(1, 2).contiguous()
    return x.view(batchsize, -1, height, width)


class Lite_EffiBlockS1(nn.Module):
    """Split-and-shuffle block.

    The input's channels are split into two equal halves:
    - the first half passes through unchanged;
    - the second half goes pointwise conv -> depthwise conv -> SE -> pointwise conv.

    The halves are concatenated and channel-shuffled with 2 groups.

    NOTE(review): ``ConvBNHS`` and ``ConvBN`` are not defined or imported in this file
    as shown; confirm they are provided before constructing this block.
    """

    def __init__(self, in_channels, mid_channels, out_channels, stride):
        super().__init__()
        self.conv_pw_1 = ConvBNHS(
            in_channels=in_channels // 2,
            out_channels=mid_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            groups=1,
        )
        self.conv_dw_1 = ConvBN(
            in_channels=mid_channels,
            out_channels=mid_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            groups=mid_channels,
        )
        self.se = SEBlock(mid_channels)
        self.conv_1 = ConvBNHS(
            in_channels=mid_channels,
            out_channels=out_channels // 2,
            kernel_size=1,
            stride=1,
            padding=0,
            groups=1,
        )

    def forward(self, inputs):
        half = inputs.shape[1] // 2
        passthrough, branch = inputs.split([half, half], dim=1)
        branch = self.conv_pw_1(branch)
        branch = self.conv_dw_1(branch)
        branch = self.se(branch)
        branch = self.conv_1(branch)
        return channel_shuffle(torch.cat((passthrough, branch), dim=1), 2)


class Lite_EffiBlockS2(nn.Module):
    """Two-branch block whose depthwise convs use the given ``stride``.

    Branch 1 is a depthwise 3x3 followed by a pointwise conv. Branch 2 is
    pointwise -> depthwise 3x3 -> SE -> pointwise. Each branch produces
    ``out_channels // 2`` channels. The concatenation then goes through a depthwise
    3x3 and a pointwise conv.

    NOTE(review): ``ConvBNHS`` and ``ConvBN`` are not defined or imported in this file
    as shown; confirm they are provided before constructing this block.
    """

    def __init__(self, in_channels, mid_channels, out_channels, stride):
        super().__init__()
        half_mid = mid_channels // 2
        half_out = out_channels // 2
        # branch1
        self.conv_dw_1 = ConvBN(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            groups=in_channels,
        )
        self.conv_1 = ConvBNHS(
            in_channels=in_channels, out_channels=half_out, kernel_size=1, stride=1, padding=0, groups=1
        )
        # branch2
        self.conv_pw_2 = ConvBNHS(
            in_channels=in_channels, out_channels=half_mid, kernel_size=1, stride=1, padding=0, groups=1
        )
        self.conv_dw_2 = ConvBN(
            in_channels=half_mid,
            out_channels=half_mid,
            kernel_size=3,
            stride=stride,
            padding=1,
            groups=half_mid,
        )
        self.se = SEBlock(half_mid)
        self.conv_2 = ConvBNHS(
            in_channels=half_mid, out_channels=half_out, kernel_size=1, stride=1, padding=0, groups=1
        )
        # fusion of both branches
        self.conv_dw_3 = ConvBNHS(
            in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1, groups=out_channels
        )
        self.conv_pw_3 = ConvBNHS(
            in_channels=out_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0, groups=1
        )

    def forward(self, inputs):
        left = self.conv_1(self.conv_dw_1(inputs))
        right = self.conv_pw_2(inputs)
        right = self.conv_dw_2(right)
        right = self.conv_2(self.se(right))
        merged = torch.cat((left, right), dim=1)
        return self.conv_pw_3(self.conv_dw_3(merged))


class DPBlock(nn.Module):
    """Depthwise conv -> BN -> Hardswish, then pointwise conv -> BN -> Hardswish.

    The first conv uses ``groups=out_channel`` with "same"-style padding
    ``(kernel_size - 1) // 2``. ``forward_fuse`` runs the same path without the BN layers.
    """

    def __init__(self, in_channel=96, out_channel=96, kernel_size=3, stride=1):
        super().__init__()
        self.conv_dw_1 = nn.Conv2d(
            in_channel,
            out_channel,
            kernel_size,
            stride=stride,
            padding=(kernel_size - 1) // 2,
            groups=out_channel,
        )
        self.bn_1 = nn.BatchNorm2d(out_channel)
        self.act_1 = nn.Hardswish()
        self.conv_pw_1 = nn.Conv2d(out_channel, out_channel, 1, padding=0, groups=1)
        self.bn_2 = nn.BatchNorm2d(out_channel)
        self.act_2 = nn.Hardswish()

    def forward(self, x):
        depthwise = self.act_1(self.bn_1(self.conv_dw_1(x)))
        return self.act_2(self.bn_2(self.conv_pw_1(depthwise)))

    def forward_fuse(self, x):
        depthwise = self.act_1(self.conv_dw_1(x))
        return self.act_2(self.conv_pw_1(depthwise))


class DarknetBlock(nn.Module):
    """A 1x1 ``ConvBNHS`` to ``int(out_channels * expansion)`` channels, followed by a ``DPBlock``.

    NOTE(review): ``ConvBNHS`` is not defined or imported in this file as shown;
    confirm it is provided before constructing this block.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, expansion=0.5):
        super().__init__()
        hidden = int(out_channels * expansion)
        self.conv_1 = ConvBNHS(
            in_channels=in_channels,
            out_channels=hidden,
            kernel_size=1,
            stride=1,
            padding=0,
        )
        self.conv_2 = DPBlock(
            in_channel=hidden,
            out_channel=out_channels,
            kernel_size=kernel_size,
            stride=1,
        )

    def forward(self, x):
        return self.conv_2(self.conv_1(x))


class CSPBlock(nn.Module):
    """CSP block with two parallel paths.

    Path 1 is a 1x1 conv followed by a ``DarknetBlock``; path 2 is a 1x1 conv. Their
    outputs are concatenated, and a 1x1 conv projects the result to ``out_channels``.

    NOTE(review): ``ConvBNHS`` is not defined or imported in this file as shown;
    confirm it is provided before constructing this block.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, expand_ratio=0.5):
        super().__init__()
        hidden = int(out_channels * expand_ratio)
        self.conv_1 = ConvBNHS(in_channels, hidden, 1, 1, 0)
        self.conv_2 = ConvBNHS(in_channels, hidden, 1, 1, 0)
        self.conv_3 = ConvBNHS(2 * hidden, out_channels, 1, 1, 0)
        self.blocks = DarknetBlock(hidden, hidden, kernel_size, 1.0)

    def forward(self, x):
        main = self.blocks(self.conv_1(x))
        side = self.conv_2(x)
        return self.conv_3(torch.cat((main, side), dim=1))
